Skip to content

[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052

Open
allenphilipj wants to merge 3 commits into
NVIDIA:mainfrom
allenphilipj:codex-grouped-linear-fp8-cudagraph-skip
Open

[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052
allenphilipj wants to merge 3 commits into
NVIDIA:mainfrom
allenphilipj:codex-grouped-linear-fp8-cudagraph-skip

Conversation

@allenphilipj
Copy link
Copy Markdown

@allenphilipj allenphilipj commented May 28, 2026

Summary:

  • Propagate the FP8 graph-capture skip_fp8_weight_update tensor through GroupedLinear.
  • Align GroupedLinear graph-capture handling with Linear, LayerNormLinear, and LayerNormMLP.

Validation:

  • git diff --check
  • python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.py
  • Not run: GPU test suite not available in this local environment.

Fixes #3051

@allenphilipj allenphilipj requested a review from ksivaman as a code owner May 28, 2026 12:36
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from 937ef34 to 80304fa Compare May 28, 2026 12:40
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes GroupedLinear to correctly propagate the FP8 graph-capture skip_fp8_weight_update tensor, aligning it with Linear, LayerNormLinear, and LayerNormMLP. Previously the field was hard-coded to None, causing the CUDA graph to always re-quantize FP8 weights even when the graph signals that the update should be skipped.

  • Adds the same FP8GlobalStateManager.fp8_graph_capturing() guard block that the other three modules already use, retrieving skip_fp8_weight_update_tensor and setting is_first_microbatch = False when inside a graph-capture context.
  • Replaces the hard-coded None with skip_fp8_weight_update in the non_tensor_args tuple passed to _GroupedLinear.forward, so the tensor flows through to quantize_weight via _prepare_weights_for_grouped_tensor_gemm.

Confidence Score: 5/5

The change is a minimal, targeted fix — two lines are added to retrieve the graph-capture tensor and one None literal is replaced. The new code is structurally identical to the already-tested pattern in Linear, LayerNormLinear, and LayerNormMLP.

The new code block is a direct copy of the well-established pattern used by the other three TE linear modules. The only variable affected flows unchanged into an existing parameter slot; no control flow, dtypes, or tensor shapes are altered. The fix is additive and isolated to the FP8 graph-capture path.

No files require special attention; the single changed file is straightforward and self-contained.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Propagates the FP8 graph-capture skip_fp8_weight_update tensor in GroupedLinear.forward(), mirroring the identical pattern already present in Linear, LayerNormLinear, and LayerNormMLP. Previously the field was hard-coded to None.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GroupedLinear.forward
    participant FP8GlobalStateManager
    participant _GroupedLinear.forward
    participant _prepare_weights_for_grouped_tensor_gemm
    participant quantize_weight

    Caller->>GroupedLinear.forward: forward(inp, m_splits, is_first_microbatch)
    GroupedLinear.forward->>FP8GlobalStateManager: fp8_graph_capturing()?
    alt FP8 graph capture active
        FP8GlobalStateManager-->>GroupedLinear.forward: skip_fp8_weight_update tensor
        GroupedLinear.forward->>GroupedLinear.forward: "is_first_microbatch = False"
    else Normal execution
        GroupedLinear.forward->>GroupedLinear.forward: "skip_fp8_weight_update = None"
    end
    GroupedLinear.forward->>GroupedLinear.forward: "cache_weight = (is_first_microbatch is not None)"
    GroupedLinear.forward->>_GroupedLinear.forward: non_tensor_args (includes skip_fp8_weight_update)
    _GroupedLinear.forward->>_prepare_weights_for_grouped_tensor_gemm: skip_fp8_weight_update
    _prepare_weights_for_grouped_tensor_gemm->>quantize_weight: "skip_update_flag=skip_fp8_weight_update"
    quantize_weight-->>_GroupedLinear.forward: (cached or freshly-quantized) FP8 weight
Loading

Reviews (12): Last reviewed commit: "Merge branch 'main' into codex-grouped-l..." | Re-trigger Greptile

Comment thread tests/pytorch/test_cuda_graphs.py Outdated
@ksivaman
Copy link
Copy Markdown
Member

/te-ci pytorch

@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from d7a4caa to 1890acf Compare June 2, 2026 16:38
@allenphilipj
Copy link
Copy Markdown
Author

@ksivaman I've rebased on the latest main & resolved the conflicts, would much appreciate a follow-up review.

Copy link
Copy Markdown
Member

@ksivaman ksivaman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is not necessary here, the change itself looks good.

Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from 5fd52b1 to e445000 Compare June 4, 2026 10:53
@allenphilipj
Copy link
Copy Markdown
Author

allenphilipj commented Jun 4, 2026

@ksivaman I dropped the test per feedback. The PR now contains only the GroupedLinear skip_fp8_weight_update propagation change.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 4, 2026

Want your agent to iterate on Greptile's feedback? Try greploops.

@ksivaman
Copy link
Copy Markdown
Member

ksivaman commented Jun 4, 2026

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture

2 participants